library(recipes)
library(data.table)
library(tidyverse)

eol_cleaning_pipeline <- function(.data) {
  # Main data-cleaning pipeline; it can also be run on a subsample.
  #
  # Stage 1 runs before collect(), so when `.data` is a lazy database table
  # these steps are executed on the SQL side. Put any preprocessing the
  # database can do into this stage.
  sql_stage <- .data %>%
    drop_dead_on_index_date() %>%
    drop_young_patients() %>%
    add_grouped_cancer_type()

  # Stage 2 runs in R on the collected data.
  sql_stage %>%
    collect() %>%
    preprocess_column_types() %>%
    add_id_columns() %>%
    fix_missings() %>%
    add_age_quantiles()
}

prepare_meta_data <- function(meta) {
  # Assign a column type to metadata rows based on the feature-name pattern
  # and add a `feature.name` column with two spelling fixes.
  #
  # `feature.type` becomes "binary" for names matching `binary_pattern` and
  # "continuous" for names matching `numeric_pattern` or `ordinal_pattern`.
  # Rows matching none of the patterns keep their existing `feature.type`.
  # NOTE(review): the "...admDays" -> "...admDay" renames are written to a
  # new `feature.name` column while `feature_name` itself is left unchanged;
  # confirm this is intended.
  binary_pattern <- "BT_.*_sw|MED_sw"
  numeric_pattern <- "BT_.*days_from|BT_.*val_num|MED_num|MED_days_from_last|UTL_l365d_total_c"
  ordinal_pattern <- "BT_.*val_ord|BT_.*val_cat"

  mutate(
    meta,
    feature.type = case_when(
      str_detect(feature_name, !!binary_pattern) ~ "binary",
      str_detect(feature_name, !!numeric_pattern) ~ "continuous",
      # ordinal / categorical "val_" features are treated as continuous too
      str_detect(feature_name, !!ordinal_pattern) ~ "continuous",
      TRUE ~ feature.type
    ),
    feature.name = case_when(
      feature_name == "UTL_l365d_hospUnplanned_admDays" ~ "UTL_l365d_hospUnplanned_admDay",
      feature_name == "UTL_f365d_hospUnplanned_admDays" ~ "UTL_f365d_hospUnplanned_admDay",
      TRUE ~ feature_name
    )
  )
}

preprocess_column_types <- function(eol_data) {
  # Cast the columns of `eol_data` to the types recorded in the metadata
  # (load_meta_data(), adjusted by prepare_meta_data()):
  #   categorical, binary -> factor
  #   continuous          -> numeric
  #   date                -> Date
  # Columns the metadata does not describe, or whose type is none of these
  # four, are returned unchanged.
  column_types <- c("categorical", "binary", "continuous", "date")

  # Keep the metadata rows for the columns of eol_data. Columns unknown to
  # the metadata get an NA feature.type and are therefore never cast.
  meta <- load_meta_data() %>%
    right_join(tibble(feature_name = colnames(eol_data)), by = "feature_name") %>%
    prepare_meta_data()

  # Named list: type -> character vector of column names of that type.
  # Names are resolved eagerly (no lazily evaluated vars() wrapper).
  vars2preprocess <- column_types %>%
    set_names() %>%
    map(~ as.character(filter(meta, feature.type == .x)$feature_name))

  eol_data %>%
    mutate_at(vars2preprocess$categorical, factor) %>%
    mutate_at(vars2preprocess$binary, as.factor) %>%
    mutate_at(vars2preprocess$continuous, as.numeric) %>%
    mutate_at(vars2preprocess$date, as.Date)
}

fix_missings <- function(.data) {
  # Give every factor column an explicit level for its NA entries (using
  # forcats::fct_explicit_na's default label), so missingness becomes its
  # own category. Non-factor columns are left untouched.
  mutate_if(.data, .predicate = is.factor, .funs = forcats::fct_explicit_na)
}

drop_dead_on_index_date <- function(.data) {
  # Drop patients who died on their index date. A row is kept when it has no
  # recorded date of death, or when the date of death differs from the
  # index date. A row with a death date but a missing index date evaluates
  # to NA and is dropped as well.
  .data %>%
    filter(is.na(DMG_date_of_death_XX) | DMG_date_of_death_XX != S_index_date_XX)
}

drop_young_patients <- function(.data, min_age = 25) {
  # Keep only patients whose DMG_age is at least `min_age` (default 25).
  # Younger patients, and rows where DMG_age is missing, are dropped.
  stopifnot(is.numeric(min_age), length(min_age) == 1L)
  # `!!` inlines the value, so it can never be mistaken for a column and is
  # sent as a literal when `.data` is a database table.
  filter(.data, DMG_age >= !!min_age)
}

add_grouped_cancer_type <- function(.data, n_top = 20, db_con = con) {
  # Add CNRS_topo_main_groups. It holds the value of CNRS_topo_grouped_desc
  # when that value is one of the `n_top` most common cancer groups (as
  # ranked by get_top_cancer()), and "Others" otherwise. Two long group
  # names are then replaced by shorter labels.
  #
  # `db_con` is the connection passed to get_top_cancer(). It defaults to
  # the object named `con` visible from where this function is defined,
  # which is what this function previously relied on implicitly.
  top_cancers <- get_top_cancer(db_con, n_top) %>% pluck(1)

  .data %>%
    mutate(
      CNRS_topo_main_groups = if_else(
        CNRS_topo_grouped_desc %in% !!top_cancers,
        CNRS_topo_grouped_desc,
        "Others"
      ),
      # shorten long group names
      CNRS_topo_main_groups = case_when(
        CNRS_topo_main_groups == "Secondary and unspecified malignant neoplasm of Lymph nodes" ~ "Lymph nodes (secondary?)",
        CNRS_topo_main_groups == "Hematopoietic and reticuloendothelial systems" ~ "Hemato. and reticul. systems",
        TRUE ~ CNRS_topo_main_groups
      )
    )
}

get_top_cancer <- function(con, top_n) {
  # Return the `top_n` most frequent CNRS_topo_grouped_desc values among
  # rows of the main EOL table with S_sample_source_XX == "cnr_data".
  # The result is a collected table with columns CNRS_topo_grouped_desc and
  # n (row count), sorted by decreasing n.
  main_table <- get_eol_table_names()$main

  get_z_table(con, main_table) %>%
    filter(S_sample_source_XX == "cnr_data") %>%
    group_by(CNRS_topo_grouped_desc) %>%
    # n() is dplyr's row counter and translates to COUNT(*) on database
    # backends. A bare count() is not a dplyr aggregate and would be passed
    # to SQL verbatim.
    summarise(n = n()) %>%
    arrange(desc(n)) %>%
    head(top_n) %>%
    collect()
}

add_age_quantiles <- function(.data) {
  # Add DMG_age_quintiles: DMG_age split into five quantile-based bins by
  # gtools::quantcut. Intervals are closed on the left (right = FALSE) and
  # labelled "1_age_quin" ... "5_age_quin".
  quintile_probs <- seq(0, 1, 0.2)
  quintile_labels <- paste0(1:5, "_age_quin")

  mutate(
    .data,
    DMG_age_quintiles = gtools::quantcut(
      DMG_age,
      q = !!quintile_probs,
      right = FALSE,
      labels = !!quintile_labels
    )
  )
}

add_id_columns <- function(.data) {
  # Add integer group ids computed with dplyr::group_indices():
  #   id             - one id per id_var value;
  #   ishpuz_code    - one id per S_zihui_ishpuz_bikur_XX value;
  #   obs_uniq_ident - one id per (id_var, S_index_date_XX,
  #                    S_sample_source_XX) combination.
  # All three are computed on the incoming data.
  # NOTE(review): recent dplyr versions deprecate passing columns to
  # group_indices(); confirm and consider group_by() + cur_group_id().
  patient_ids <- group_indices(.data, id_var)
  ishpuz_ids <- group_indices(.data, S_zihui_ishpuz_bikur_XX)
  observation_ids <- group_indices(.data, id_var, S_index_date_XX, S_sample_source_XX)

  mutate(
    .data,
    id = !!patient_ids,
    ishpuz_code = !!ishpuz_ids,
    obs_uniq_ident = !!observation_ids
  )
}


### --- preparation for modeling

make_matrix_recipe <- function(data, outcome) {
  # Build an (untrained) recipes::recipe that prepares `data` for modelling.
  #
  # `outcome` is an unquoted column name that receives the "outcome" role.
  #   1. Roles: every column matched by the contains() list below becomes a
  #      "non-predictor". index_date_month and index_date_year are then
  #      restored as predictors.
  #   2. make_spec() inspects the remaining predictors of `data` to decide
  #      which features to drop, switch-encode, factor or discretize.
  #   3. The returned recipe does the following:
  #      - removes non-predictors and near-constant features;
  #      - recodes mostly-missing features as 0/1;
  #      - turns low-cardinality features into factors (NA -> "missing",
  #        rare levels pooled by step_other);
  #      - discretizes high-cardinality features into 5 bins;
  #      - dummy-encodes all nominal predictors.
  # NOTE(review): the formula hard-codes DMG_died_within_365d as the initial
  # outcome; `outcome` only re-assigns roles afterwards. Confirm this is
  # intended when `outcome` names a different column.
  recipe_draft <- data %>%
    recipe(DMG_died_within_365d ~ .) %>%
    update_role(
      contains("_XX"),
      contains("hashed_id"),
      contains("date"),
      contains("UTL_f"),
      contains("f1y"),
      contains("died"),
      contains("desc"),
      # contains() matches a literal substring, not a regex, so the two
      # prefixes are listed separately instead of as "hativaD_|wardTypeD"
      contains("hativaD_"),
      contains("wardTypeD"),
      contains("before_death"),
      contains("f365d"),
      contains("ADM_num_admissions_f"),
      contains("CNRS_num_of_cancers_diag"),
      contains("UTL_365d_beforeDeath_tot"),
      contains("taarich_ishpuz"),
      new_role = "non-predictor"
    ) %>%
    update_role(index_date_month, index_date_year, new_role = "predictor") %>%
    update_role({{ outcome }}, new_role = "outcome")

  # select the columns to preprocess based on their roles
  relevant_cols <- summary(recipe_draft) %>%
    filter(role == "predictor") %>%
    pluck("variable")

  # spec lists which features need which preprocessing
  spec <- make_spec(data, relevant_cols)

  recipe_draft %>%
    step_rm(has_role("non-predictor")) %>%
    # remove features where a single value (or NA) covers > 99% of rows
    step_rm(!!spec$features_to_drop) %>%
    # features with less than 5% non-missing values become 0/1 switches
    step_mutate_at(!!spec$features_to_sw, fn = function(x) if_else(is.na(x), 0, 1)) %>%
    # low-cardinality features become factors with an explicit "missing" level
    step_mutate_at(
      !!spec$features_to_factor,
      fn = function(x) fct_explicit_na(as.factor(x), na_level = "missing")
    ) %>%
    step_other(!!spec$features_to_factor, threshold = 0.01) %>%
    # discretize high-cardinality features
    step_discretize(
      !!spec$features_to_cut,
      num_breaks = 5,
      min_unique = 2,
      options = list(na.rm = TRUE)
    ) %>%
    step_dummy(all_nominal(), -all_outcomes())
}

make_spec <- function(data, features_list) {
  # Split the candidate features into preprocessing groups:
  #   features_to_drop   - from get_features_to_drop() on all candidates;
  #   features_to_sw     - from get_features_to_sw() on the remaining
  #                        numeric candidates;
  #   features_to_factor - from get_features_to_factor() on what is left;
  #   features_to_cut    - from get_features_to_cut() on what is left.
  # Non-numeric columns can only end up in features_to_drop.
  to_drop <- get_features_to_drop(data, features_list)
  non_numeric <- colnames(select_if(data, ~ !is.numeric(.x)))
  candidates <- setdiff(setdiff(features_list, to_drop), non_numeric)

  to_sw <- get_features_to_sw(data, candidates)
  candidates <- setdiff(candidates, to_sw)

  list(
    features_to_drop   = to_drop,
    features_to_sw     = to_sw,
    features_to_factor = get_features_to_factor(data, candidates),
    features_to_cut    = get_features_to_cut(data, candidates)
  )
}

get_features_to_drop <- function(.data, features_list, threshold = 0.99) {
  # Find near-constant features. A feature qualifies when its single most
  # frequent value (NA counts as a value) covers more than `threshold` of
  # the rows.
  #
  # Returns a character vector of feature names in `features_list` order.
  # Stops with an error if a requested feature is not a column of `.data`.
  features_list <- as.character(features_list)
  unknown <- setdiff(features_list, colnames(.data))
  if (length(unknown) > 0) {
    stop("Unknown feature(s): ", paste(unknown, collapse = ", "), call. = FALSE)
  }

  # Index the result by column name. Recovering names by parsing sapply()'s
  # "<column>.<value>" names would truncate column names containing a dot.
  top_share <- vapply(
    features_list,
    function(col) {
      shares <- prop.table(table(.data[[col]], useNA = "ifany"))
      # a zero-row column has no most-frequent value
      if (length(shares) == 0L) NA_real_ else max(shares)
    },
    numeric(1),
    USE.NAMES = FALSE
  )
  features_list[!is.na(top_share) & top_share > threshold]
}

get_features_to_sw <- function(data, features_list) {
  # Return the features of `features_list` whose NA count exceeds 95% of the
  # rows (i.e. fewer than 5% non-missing values). These are later recoded
  # as binary switches.
  na_summary <- get_features_with_na_count(data, features_list)
  too_sparse <- na_summary$na_count > 0.95 * nrow(data)
  na_summary$names[too_sparse]
}

# Private helper: among the features in `features_list` that contain at
# least one NA, count the distinct values of each (NA counts as one value).
# Returns list(names = <character>, n_unique = <integer>) in input order.
# Stops with an error if a requested feature is not a column of `data`.
.na_feature_unique_counts <- function(data, features_list) {
  features_list <- unique(as.character(features_list))
  unknown <- setdiff(features_list, colnames(data))
  if (length(unknown) > 0) {
    stop("Unknown feature(s): ", paste(unknown, collapse = ", "), call. = FALSE)
  }
  has_na <- vapply(features_list, function(col) anyNA(data[[col]]), logical(1))
  na_features <- features_list[has_na]
  list(
    names = na_features,
    n_unique = vapply(
      na_features,
      function(col) length(unique(data[[col]])),
      integer(1),
      USE.NAMES = FALSE
    )
  )
}

get_features_to_factor <- function(data, featurues_list, max_unique = 10) {
  # Find features that contain NAs and have at most `max_unique` distinct
  # values (NA included); these should be turned into factors.
  # Returns character(0) when no feature qualifies.
  # (The parameter keeps its original spelling so named calls still work.)
  counts <- .na_feature_unique_counts(data, featurues_list)
  counts$names[counts$n_unique <= max_unique]
}

get_features_to_cut <- function(data, features_list, max_unique = 10) {
  # Find features that contain NAs and have more than `max_unique` distinct
  # values (NA included); these should be discretized.
  # Returns character(0) when no feature qualifies.
  counts <- .na_feature_unique_counts(data, features_list)
  counts$names[counts$n_unique > max_unique]
}

get_features_with_na_count <- function(.data, features_list) {
  # Count missing values for each feature of `features_list` that has any.
  #
  # Returns a tibble with columns `names` (feature name) and `na_count`
  # (integer). It has one row per feature with at least one NA, in
  # `features_list` order, and zero rows when no feature has NAs. It is
  # built directly rather than with pivot_longer(), which errors when no
  # column is selected.
  # Stops with an error if a requested feature is not a column of `.data`.
  features_list <- unique(as.character(features_list))
  unknown <- setdiff(features_list, colnames(.data))
  if (length(unknown) > 0) {
    stop("Unknown feature(s): ", paste(unknown, collapse = ", "), call. = FALSE)
  }

  na_count <- vapply(
    features_list,
    function(col) sum(is.na(.data[[col]])),
    integer(1),
    USE.NAMES = FALSE
  )
  has_na <- na_count > 0L
  tibble::tibble(names = features_list[has_na], na_count = na_count[has_na])
}
